{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Reinforcement Learning (DQN) 教學\n",
    "=====================================\n",
    "This tutorial shows how to use PyTorch to train a Deep Q Learning (DQN) agent\n",
    "on the CartPole-v0 task from the `OpenAI Gym <https://gym.openai.com/>`__.\n",
    "\n",
    "**Task**\n",
    "\n",
    "The agent has to decide between two actions - moving the cart left or\n",
    "right - so that the pole attached to it stays upright. \n",
    "- Cartpole：在一輛小車上豎立一根桿子，將小車向左或向右移動（一個推或者拉的力），使得桿子盡量保持平衡不倒下。\n",
    "\n",
    "As the agent observes the current state of the environment and chooses\n",
    "an action, the environment *transitions* to a new state, and also\n",
    "returns a reward that indicates the consequences of the action. In this\n",
    "task, the environment terminates if the pole falls over too far.\n",
    "\n",
    "The CartPole task is designed so that the inputs to the agent are 4 real\n",
    "values representing the environment state (position, velocity, etc.).\n",
    "However, neural networks can solve the task purely by looking at the\n",
    "scene, so we'll use a patch of the screen centered on the cart as an\n",
    "input. Because of this, our results aren't directly comparable to the\n",
    "ones from the official leaderboard - our task is much harder.\n",
    "Unfortunately this does slow down the training, because we have to\n",
    "render all the frames.\n",
    "\n",
    "Strictly speaking, we will present the state as the difference between\n",
    "the current screen patch and the previous one. This will allow the agent\n",
    "to take the velocity of the pole into account from one image.\n",
    "\n",
    "- state(environment) : 描述強化學習問題中的外部環境，比如：Cartpole問題中桿子的角度，小車的位置、速度等\n",
    "- action : 在不同外部環境條件下採取的動作，比如：Cartpole問題中對於小車施加推或者拉的力。 action可以是離散的集合，也可以是連續的。\n",
    "- reward：對於agent/network作出的action後獲取的回報/評價。比如：Cartpole問題中如果施加的力可以繼續讓桿子保持平衡，那reward就+1\n",
    "- agent : 學習到的是如何在變化中的state和reward選擇action的能力"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "import gym\n",
    "import math\n",
    "import random\n",
    "import numpy as np\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "from collections import namedtuple\n",
    "from itertools import count\n",
    "from PIL import Image\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd import Variable\n",
    "import torchvision.transforms as T"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "env = gym.make('CartPole-v0').unwrapped\n",
    "# set up matplotlib\n",
    "is_ipython = 'inline' in matplotlib.get_backend()\n",
    "if is_ipython:\n",
    "    from IPython import display\n",
    "    \n",
    "plt.ion()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**gpu版本不支援的話註解掉**\n",
    "\n",
    "下方系統環境參數設定的部份也須註解掉入下所示：\n",
    "\n",
    "'''\n",
    "use_cuda = torch.cuda.is_available()\n",
    "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n",
    "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n",
    "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n",
    "Tensor = FloatTensor\n",
    "'''\n",
    "\n",
    "''''\n",
    "if use_cuda:\n",
    "    policy_net.cuda()\n",
    "    target_net.cuda()\n",
    "''''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# if gpu is to be used \n",
    "'''\n",
    "use_cuda = torch.cuda.is_available()\n",
    "FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor\n",
    "LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor\n",
    "ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor\n",
    "Tensor = FloatTensor\n",
    "'''\n",
    "# if gpu is not supported\n",
    "FloatTensor = torch.FloatTensor\n",
    "LongTensor = torch.LongTensor\n",
    "ByteTensor = torch.ByteTensor\n",
    "Tensor = FloatTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "Transition = namedtuple('Transition',\n",
    "                        ('state', 'action', 'next_state', 'reward'))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Replay Memory\n",
    "-------------\n",
    "\n",
    "We'll be using experience replay memory for training our DQN. It stores\n",
    "the transitions that the agent observes, allowing us to reuse this data\n",
    "later. By sampling from it randomly, the transitions that build up a\n",
    "batch are decorrelated. It has been shown that this greatly stabilizes\n",
    "and improves the DQN training procedure.\n",
    "\n",
    "For this, we're going to need two classses:\n",
    "\n",
    "-  ``Transition`` - a named tuple representing a single transition in\n",
    "   our environment\n",
    "-  ``ReplayMemory`` - a cyclic buffer of bounded size that holds the\n",
    "   transitions observed recently. It also implements a ``.sample()``\n",
    "   method for selecting a random batch of transitions for training.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class ReplayMemory(object):\n",
    "\n",
    "    def __init__(self, capacity):\n",
    "        self.capacity = capacity\n",
    "        self.memory = []\n",
    "        self.position = 0\n",
    "\n",
    "    def push(self, *args):\n",
    "        \"\"\"Saves a transition.\"\"\"\n",
    "        if len(self.memory) < self.capacity:#如果暫存的所有階段容量>記憶體容量 則清空\n",
    "            self.memory.append(None)\n",
    "        self.memory[self.position] = Transition(*args)\n",
    "        self.position = (self.position + 1) % self.capacity\n",
    "\n",
    "    def sample(self, batch_size):\n",
    "        return random.sample(self.memory, batch_size)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.memory)"
   ]
  },
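  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cyclic overwrite in `ReplayMemory.push` can also be illustrated with the standard library. The cell below is a minimal sketch, not part of the original tutorial: it assumes a hypothetical `demo_memory` built on `collections.deque(maxlen=...)`, which discards the oldest entry once the capacity is reached, and it stores plain integers instead of real transitions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from collections import deque\n",
    "\n",
    "# Hypothetical stand-in for ReplayMemory: a bounded buffer of capacity 3.\n",
    "demo_memory = deque(maxlen=3)\n",
    "for demo_item in range(5):\n",
    "    demo_memory.append(demo_item)  # items 0 and 1 are evicted along the way\n",
    "\n",
    "print(list(demo_memory))  # the three most recent items: [2, 3, 4]\n",
    "\n",
    "# Uniform sampling without replacement, as in ReplayMemory.sample.\n",
    "# The deque is converted to a list so random.sample gets a plain sequence.\n",
    "demo_batch = random.sample(list(demo_memory), 2)\n",
    "print(len(demo_batch))  # 2"
   ]
  },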
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "DQN algorithm\n",
    "-------------\n",
    "\n",
    "Our environment is deterministic, so all equations presented here are\n",
    "also formulated deterministically for the sake of simplicity. In the\n",
    "reinforcement learning literature, they would also contain expectations\n",
    "over stochastic transitions in the environment.\n",
    "\n",
    "Our aim will be to train a policy that tries to maximize the discounted,\n",
    "cumulative reward\n",
    "$R_{t_0} = \\sum_{t=t_0}^{\\infty} \\gamma^{t - t_0} r_t$, where\n",
    "$R_{t_0}$ is also known as the *return*. The discount,\n",
    "$\\gamma$, should be a constant between $0$ and $1$\n",
    "that ensures the sum converges. It makes rewards from the uncertain far\n",
    "future less important for our agent than the ones in the near future\n",
    "that it can be fairly confident about.\n",
    "\n",
    "\n",
    "The main idea behind Q-learning is that if we had a function\n",
    "$Q^*: State \\times Action \\rightarrow \\mathbb{R}$, that could tell\n",
    "us what our return would be, if we were to take an action in a given\n",
    "state, then we could easily construct a policy that maximizes our\n",
    "rewards:\n",
    "- 根據當前的Q值計算出一個最優的動作，這個policy π\n",
    "- 稱之為greedy policy，也就是\n",
    "\\begin{align}\\pi^*(s) = \\arg\\!\\max_a \\ Q^*(s, a)\\end{align}\n",
    "\n",
    "However, we don't know everything about the world, so we don't have\n",
    "access to $Q^*$. But, since neural networks are universal function\n",
    "approximators, we can simply create one and train it to resemble\n",
    "$Q^*$.\n",
    "- Q(s,a)≈Q′(s,a),我們需要更新參數w來使得Q函數逼近與最優的Q值。  <br>\n",
    "For our training update rule, we'll use a fact that every $Q$\n",
    "function for some policy obeys the Bellman equation:\n",
    "\n",
    "\\begin{align}Q^{\\pi}(s, a) = r + \\gamma Q^{\\pi}(s', \\pi(s'))\\end{align}\n",
    "\n",
    "The difference between the two sides of the equality is known as the\n",
    "temporal difference error, $\\delta$:\n",
    "\n",
    "\\begin{align}\\delta = Q(s, a) - (r + \\gamma \\max_a Q(s', a))\\end{align}\n",
    "\n",
    "To minimise this error, we will use the `Huber\n",
    "loss <https://en.wikipedia.org/wiki/Huber_loss>`__. The Huber loss acts\n",
    "like the mean squared error (定義目標函數objective function也就是loss function) when the error is small, but like the mean\n",
    "absolute error when the error is large - this makes it more robust to\n",
    "outliers when the estimates of $Q$ are very noisy. We calculate\n",
    "this over a batch of transitions, $B$, sampled from the replay\n",
    "memory:\n",
    "\n",
    "\\begin{align}\\mathcal{L} = \\frac{1}{|B|}\\sum_{(s, a, s', r) \\ \\in \\ B} \\mathcal{L}(\\delta)\\end{align}\n",
    "\n",
    "\\begin{align}\\text{where} \\quad \\mathcal{L}(\\delta) = \\begin{cases}\n",
    "     \\frac{1}{2}{\\delta^2}  & \\text{for } |\\delta| \\le 1, \\\\\n",
    "     |\\delta| - \\frac{1}{2} & \\text{otherwise.}\n",
    "   \\end{cases}\\end{align}\n",
    "\n",
    "***Q-network***\n",
    "^^^^^^^^^\n",
    "\n",
    "Our model will be a convolutional neural network that takes in the\n",
    "difference between the current and previous screen patches. It has two\n",
    "outputs, representing $Q(s, \\mathrm{left})$ and\n",
    "$Q(s, \\mathrm{right})$ (where $s$ is the input to the\n",
    "network). In effect, the network is trying to predict the *quality* of\n",
    "taking each action given the current input.\n"
   ]
  },
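  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the loss concrete, the cell below is a small sketch (not from the original tutorial) that evaluates the temporal difference error and the piecewise Huber loss defined above in plain Python. The helper names `demo_td_error` and `demo_huber` and all numeric values are illustrative assumptions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def demo_td_error(q_sa, reward, gamma, max_q_next):\n",
    "    # delta = Q(s, a) - (r + gamma * max_a Q(s', a))\n",
    "    return q_sa - (reward + gamma * max_q_next)\n",
    "\n",
    "\n",
    "def demo_huber(delta):\n",
    "    # Quadratic for |delta| <= 1, linear otherwise\n",
    "    if abs(delta) <= 1:\n",
    "        return 0.5 * delta ** 2\n",
    "    return abs(delta) - 0.5\n",
    "\n",
    "\n",
    "# Illustrative numbers only\n",
    "demo_small = demo_td_error(q_sa=1.5, reward=1.0, gamma=0.5, max_q_next=2.0)  # 1.5 - 2.0 = -0.5\n",
    "demo_large = demo_td_error(q_sa=5.0, reward=1.0, gamma=0.5, max_q_next=2.0)  # 5.0 - 2.0 = 3.0\n",
    "print(demo_huber(demo_small))  # 0.125 (quadratic branch)\n",
    "print(demo_huber(demo_large))  # 2.5 (linear branch)"
   ]
  },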
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DQN(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(DQN, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)\n",
    "        self.bn1 = nn.BatchNorm2d(16)\n",
    "        self.conv2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)\n",
    "        self.bn2 = nn.BatchNorm2d(32)\n",
    "        self.conv3 = nn.Conv2d(32, 32, kernel_size=5, stride=2)\n",
    "        self.bn3 = nn.BatchNorm2d(32)\n",
    "        self.head = nn.Linear(448, 2)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = F.relu(self.bn1(self.conv1(x)))\n",
    "        x = F.relu(self.bn2(self.conv2(x)))\n",
    "        x = F.relu(self.bn3(self.conv3(x)))\n",
    "        return self.head(x.view(x.size(0), -1))"
   ]
  },
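  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `448` input features of `self.head` follow from the shapes of the three convolutions. As a sketch, assuming the network receives 40 x 80 screen patches (the 160 x 320 crop produced by `get_screen`, resized so that its shorter side is 40), the hypothetical helper `demo_conv_out` applies the usual output-size formula for a convolution without padding."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def demo_conv_out(size, kernel_size=5, stride=2):\n",
    "    # Output length of a convolution without padding or dilation\n",
    "    return (size - kernel_size) // stride + 1\n",
    "\n",
    "\n",
    "demo_h, demo_w = 40, 80  # assumed input patch size (height, width)\n",
    "for _ in range(3):       # conv1, conv2 and conv3 all use kernel_size=5, stride=2\n",
    "    demo_h, demo_w = demo_conv_out(demo_h), demo_conv_out(demo_w)\n",
    "\n",
    "# 32 output channels in conv3, flattened for the linear head\n",
    "print(demo_h, demo_w, 32 * demo_h * demo_w)  # 2 7 448"
   ]
  },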
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "resize = T.Compose([T.ToPILImage(),\n",
    "                    T.Resize(40, interpolation=Image.CUBIC),\n",
    "                    T.ToTensor()])\n",
    "# This is based on the code from gym.\n",
    "screen_width = 600\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**取得小車所在螢幕中的位置**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_cart_location():\n",
    "    world_width = env.x_threshold * 2\n",
    "    scale = screen_width / world_width\n",
    "    return int(env.state[0] * scale + screen_width / 2.0)  # MIDDLE OF CART\n",
    "\n",
    "\n",
    "def get_screen():\n",
    "    screen = env.render(mode='rgb_array').transpose(\n",
    "        (2, 0, 1))  # transpose into torch order (CHW)\n",
    "    # Strip off the top and bottom of the screen\n",
    "    screen = screen[:, 160:320]\n",
    "    view_width = 320\n",
    "    cart_location = get_cart_location()\n",
    "    if cart_location < view_width // 2:\n",
    "        slice_range = slice(view_width)\n",
    "    elif cart_location > (screen_width - view_width // 2):\n",
    "        slice_range = slice(-view_width, None)\n",
    "    else:\n",
    "        slice_range = slice(cart_location - view_width // 2,\n",
    "                            cart_location + view_width // 2)\n",
    "    # Strip off the edges, so that we have a square image centered on a cart\n",
    "    screen = screen[:, :, slice_range]\n",
    "    # Convert to float, rescare, convert to torch tensor\n",
    "    # (this doesn't require a copy)\n",
    "    screen = np.ascontiguousarray(screen, dtype=np.float32) / 255\n",
    "    screen = torch.from_numpy(screen)\n",
    "    # Resize, and add a batch dimension (BCHW)\n",
    "    return resize(screen).unsqueeze(0).type(Tensor)"
   ]
  },
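  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cropping rule in `get_screen` keeps a 320-pixel-wide window centred on the cart and clamps it at the screen edges. The cell below is a small sketch of that rule in isolation; `demo_slice_range` is a hypothetical helper that mirrors the branch logic above, using the 600-pixel screen width and 320-pixel view width from this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def demo_slice_range(cart_location, screen_width=600, view_width=320):\n",
    "    half = view_width // 2\n",
    "    if cart_location < half:\n",
    "        return slice(view_width)              # cart near the left edge\n",
    "    elif cart_location > screen_width - half:\n",
    "        return slice(-view_width, None)       # cart near the right edge\n",
    "    return slice(cart_location - half, cart_location + half)\n",
    "\n",
    "\n",
    "# slice.indices resolves negative and open bounds against a length of 600\n",
    "for demo_x in (50, 300, 580):\n",
    "    print(demo_x, demo_slice_range(demo_x).indices(600)[:2])\n",
    "# 50 (0, 320)\n",
    "# 300 (140, 460)\n",
    "# 580 (280, 600)"
   ]
  },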
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, you can find an optimize_model function that performs a single step of the optimization. It first samples a batch, concatenates all the tensors into a single one, computes Q(st,at)Q(st,at) and V(st+1)=maxaQ(st+1,a)V(st+1)=maxaQ(st+1,a), and combines them into our loss. By defition we set V(s)=0V(s)=0 if ss is a terminal state. We also use a target network to compute V(st+1)V(st+1) for added stability. The target network has its weights kept frozen most of the time, but is updated with the policy network's weights every so often. This is usually a set number of steps but we shall use episodes for simplicity."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def select_action(state):\n",
    "    global steps_done\n",
    "    sample = random.random()\n",
    "    eps_threshold = EPS_END + (EPS_START - EPS_END) * \\\n",
    "        math.exp(-1. * steps_done / EPS_DECAY)\n",
    "    steps_done += 1\n",
    "    if sample > eps_threshold:\n",
    "        return policy_net(\n",
    "            Variable(state, volatile=True).type(FloatTensor)).data.max(1)[1].view(1, 1)\n",
    "    else:\n",
    "        return LongTensor([[random.randrange(2)]])\n",
    "\n",
    "\n",
    "episode_durations = []\n",
    "\n",
    "\n",
    "def plot_durations():\n",
    "    plt.figure(2)\n",
    "    plt.clf()\n",
    "    durations_t = torch.FloatTensor(episode_durations)\n",
    "    plt.title('Training...')\n",
    "    plt.xlabel('Episode')\n",
    "    plt.ylabel('Duration')\n",
    "    plt.plot(durations_t.numpy())\n",
    "    # Take 100 episode averages and plot them too\n",
    "    if len(durations_t) >= 100:\n",
    "        means = durations_t.unfold(0, 100, 1).mean(1).view(-1)\n",
    "        means = torch.cat((torch.zeros(99), means))\n",
    "        plt.plot(means.numpy())\n",
    "\n",
    "    plt.pause(0.001)  # pause a bit so that plots are updated\n",
    "    if is_ipython:\n",
    "        display.clear_output(wait=True)\n",
    "        display.display(plt.gcf())\n",
    "    \n",
    "def optimize_model():\n",
    "    if len(memory) < BATCH_SIZE:\n",
    "        return\n",
    "    transitions = memory.sample(BATCH_SIZE)\n",
    "    # Transpose the batch (see http://stackoverflow.com/a/19343/3343043 for\n",
    "    # detailed explanation).\n",
    "    batch = Transition(*zip(*transitions))\n",
    "\n",
    "    # Compute a mask of non-final states and concatenate the batch elements\n",
    "    non_final_mask = ByteTensor(tuple(map(lambda s: s is not None,\n",
    "                                          batch.next_state)))\n",
    "    non_final_next_states = Variable(torch.cat([s for s in batch.next_state\n",
    "                                                if s is not None]),\n",
    "                                     volatile=True)\n",
    "    state_batch = Variable(torch.cat(batch.state))\n",
    "    action_batch = Variable(torch.cat(batch.action))\n",
    "    reward_batch = Variable(torch.cat(batch.reward))\n",
    "\n",
    "    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the\n",
    "    # columns of actions taken\n",
    "    state_action_values = policy_net(state_batch).gather(1, action_batch)\n",
    "\n",
    "    \n",
    "    \n",
    "    # Compute V(s_{t+1}) for all next states.\n",
    "    next_state_values = Variable(torch.zeros(BATCH_SIZE).type(Tensor))\n",
    "    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]\n",
    "    # Compute the expected Q values\n",
    "    # Qπ(s,a)=E[rt+1+λrt+2+λ2rt+3+...|s,a]=Es′[r+λQπ(s′,a′)|s,a] -> Q∗(s,a)=Es′[r+λmaxa′Q∗(s′,a′)|s,a]\n",
    "    expected_state_action_values = (next_state_values * GAMMA) + reward_batch\n",
    "    # Undo volatility (which was used to prevent unnecessary gradients)\n",
    "    expected_state_action_values = Variable(expected_state_action_values.data)\n",
    "\n",
    "    # Compute Huber loss\n",
    "    loss = F.smooth_l1_loss(state_action_values, expected_state_action_values)\n",
    "\n",
    "    # Optimize the model\n",
    "    optimizer.zero_grad()\n",
    "    loss.backward()\n",
    "    for param in policy_net.parameters():\n",
    "        param.grad.data.clamp_(-1, 1)\n",
    "    optimizer.step()"
   ]
  },
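  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The exploration rate used by `select_action` decays exponentially with the number of steps taken, from `EPS_START` towards `EPS_END`. The sketch below evaluates the same formula at a few step counts. The constants are copied from this notebook's parameter settings (0.9, 0.05 and 200), and the `demo_` names are illustrative only."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "DEMO_EPS_START, DEMO_EPS_END, DEMO_EPS_DECAY = 0.9, 0.05, 200\n",
    "\n",
    "\n",
    "def demo_eps_threshold(steps_done):\n",
    "    # Same schedule as in select_action: exponential decay towards DEMO_EPS_END\n",
    "    return DEMO_EPS_END + (DEMO_EPS_START - DEMO_EPS_END) * \\\n",
    "        math.exp(-1. * steps_done / DEMO_EPS_DECAY)\n",
    "\n",
    "\n",
    "# Probability of taking a random action after a given number of steps\n",
    "for demo_steps in (0, 200, 1000):\n",
    "    print(demo_steps, round(demo_eps_threshold(demo_steps), 3))"
   ]
  },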
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**開始訓練**\n",
    "從隨機的策略開始，使用隨機的策略進行試驗，就可以得到一系列的狀態,動作和反饋：\n",
    "- 初始化Q(s,a),∀s∈S,a∈A(s), 任意的數值，並且Q(terminal−state,⋅)=0  <br>\n",
    " 重複（對每一節episode）:  <br>\n",
    "    初始化 狀態S   <br>\n",
    "    重複（對episode中的每一步）：  <br>\n",
    "        使用某一個policy比如（ϵ−greedy)根據狀態S選取一個動作執行  <br>\n",
    "        執行完動作後，觀察reward和新的狀態S′  <br>\n",
    "        Q(St,At)←Q(St,At)+α(Rt+1+λmaxaQ(St+1,a)−Q(St,At))  <br>\n",
    "        S←S′  <br>\n",
    "     循環直到S終止  <br>"
   ]
  },
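  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The pseudocode above can be sketched as tabular Q-learning on a toy problem. Everything in the cell below is an illustrative assumption rather than part of the original tutorial: a hypothetical five-state chain in which stepping right into the last state ends the episode with reward 1, plus hand-picked values for the step size, discount and exploration rate. A private `random.Random` instance is used so the notebook's global random state is left untouched."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "toy_rng = random.Random(0)\n",
    "TOY_N_STATES = 5          # states 0..4; state 4 is terminal\n",
    "TOY_ACTIONS = (0, 1)      # 0 = left, 1 = right\n",
    "TOY_ALPHA, TOY_DISCOUNT, TOY_EPSILON = 0.5, 0.9, 0.1\n",
    "\n",
    "# Q(terminal, .) stays 0 because terminal entries are never updated\n",
    "toy_q = [[0.0, 0.0] for _ in range(TOY_N_STATES)]\n",
    "\n",
    "\n",
    "def toy_step(s, a):\n",
    "    s_next = max(0, s - 1) if a == 0 else s + 1\n",
    "    finished = s_next == TOY_N_STATES - 1\n",
    "    return s_next, (1.0 if finished else 0.0), finished\n",
    "\n",
    "\n",
    "def toy_choose_action(s):\n",
    "    # epsilon-greedy with random tie-breaking among equally good actions\n",
    "    if toy_rng.random() < TOY_EPSILON:\n",
    "        return toy_rng.choice(TOY_ACTIONS)\n",
    "    best = max(toy_q[s])\n",
    "    return toy_rng.choice([a for a in TOY_ACTIONS if toy_q[s][a] == best])\n",
    "\n",
    "\n",
    "for _ in range(200):                      # repeat for each episode\n",
    "    toy_s, toy_done = 0, False            # initialize the state\n",
    "    while not toy_done:                   # repeat for each step\n",
    "        toy_a = toy_choose_action(toy_s)\n",
    "        toy_s_next, toy_r, toy_done = toy_step(toy_s, toy_a)\n",
    "        toy_target = toy_r + TOY_DISCOUNT * max(toy_q[toy_s_next])\n",
    "        toy_q[toy_s][toy_a] += TOY_ALPHA * (toy_target - toy_q[toy_s][toy_a])\n",
    "        toy_s = toy_s_next\n",
    "\n",
    "# Greedy action in each non-terminal state after training\n",
    "print([max(TOY_ACTIONS, key=lambda a: toy_q[s][a]) for s in range(TOY_N_STATES - 1)])"
   ]
  },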
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**系統環境參數設定**\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAADWCAYAAADBwHkCAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvNQv5yAAAFKhJREFUeJzt3X2wXHV9x/H3JzcPBAiQkIQGEr2KkScHLogBB2shgI1YBae2SlsbLFZtcYQRlAdnFK2dypSnzthBRRAUBTWKYIpKDKGWVoEEAgYCJMQoIZeEIJkEwZCHb/84vwtn772b3Xv36dxzP6+ZM7u/c86e/ezZe7979re756eIwMzMRr4xnQ5gZmbN4YJuZlYSLuhmZiXhgm5mVhIu6GZmJeGCbmZWEi7o1naSzpJ0T6dzFImkbkkhaWyns9jI5YJeMpLWSnpJ0gu56cudztVpkk6UtK6F279U0k2t2r5ZPXw0UE7vjoifdzrESCNpbETs6HSOVijzY7NX+Qh9FJF0jaQFufZlkhYrM1nSQknPSno+XZ+ZW/duSV+U9H/pqP/HkvaX9G1JWyTdL6k7t35I+oSkNZI2Sfp3SYP+vUk6VNIiSb+X9Likv97NY9hX0nWSeiU9nTJ11Xh8ewE/AQ7MvWs5MB1VL5B0k6QtwFmS5kj6paTN6T6+LGl8bptH5LJukHSJpHnAJcD707YfqiNrl6TL075ZA7yrxnN3YdrG1rSPTs5t5xJJT6ZlyyTNyj0H50haBayqta8lTUiZfpce21ckTUzLTpS0TtL5kjamx/Sh3WW2DogITyWagLXAKVWW7Qk8AZwF/CmwCZiZlu0P/GVaZxLwfeBHudveDawGDgb2BR5N2zqF7J3eN4Fv5NYPYAkwBXhNWvfDadlZwD3p+l7AU8CH0naOSbmOqPIYfgR8Nd1uOnAf8NE6Ht+JwLp+27oU2A6cQXZwMxF4M3B8ytINrATOS+tPAnqB84E9Uvu43LZuGkLWjwGPAbPSPlqS9tnYQR7zIWkfHZja3cDB6fqngF+ndQQcBeyfew4Wpe1PrLWvgauB29P6k4AfA/+W2387gC8A44DTgBeByZ3+m/eU+1vpdABPTX5Cs4L+ArA5N/1jbvkc4PfAb4Ezd7OdHuD5XPtu4DO59hXAT3LtdwPLc+0A5uXa/wwsTtfP4tWC/n7gf/rd91eBzw2S6QBgGzAxN+9MYEmtx0f1gv6LGvvzPODW3H09WGW9S8kV9FpZgbuAj+WWvYPqBf0NwEayF89x/ZY9DpxeJVMAc3Ptqvua7MXgD6QXirTsrcBvcvvvpXy+lOn4Tv/Ne3p1ch96OZ0RVfrQI+K+9BZ/OvC9vvmS9gSuAuYBk9PsSZK6ImJnam/IbeqlQdp797u7p3LXfwscOEik1wLHSdqcmzcW+FaVdccBvZL65o3J30+1x7cb+YxIeiNwJXAs2RH/WGBZWjwLeLKObdaT9UAG7p9BRcRqSeeRvWgcIelnwCcjYn0dmfL3sbt9PY3s8S7L5RXQlVv3uajsh3+Rgc+5dZD70EcZSecAE4D1wKdzi84ne9t+XETsA7y97yYN3N2s3PXXpPvs7yngvyNiv9y0d0T8U5V1twFTc+vuExFH9K2wm8dX7bSi/edfQ9YVMjvth0t4dR88RdblVM92amXtZeD+qSoivhMRbyMrygFcVkem/rl2t683kb0oH5Fbtm9EuGCPIC7oo0g6+vwi8HfAB4FPS+pJiyeR/UNvljSF7G14oz6VPmydBZwLfHeQdRYCb5T0QUnj0vQWSYf1XzEieoE7gSsk7SNpjKSDJf1ZHY9vA7C/pH1rZJ4EbAFekHQokH9hWQj8iaTz0geIkyQdl9t+d98Hv7Wykr17+ISkmZImAxdVCyTpEElzJU0A/kj2PPW9a/o68C+SZitzpKT9q2yq6r6OiF3AtcBVkqan+z1I0p/X2F9WIC7o5fRj
VX4P/VZlP1i5CbgsIh6KiFVkR5/fSoXiarIPzjYBvwJ+2oQct5F1VywH/gu4rv8KEbGVrP/4A2RH1c+QHX1OqLLNvwfGk30o+zywAJhR6/FFxGPAzcCa9A2Wwbp/AC4A/gbYSlbgXnkRSllPJfu84Bmyb46clBZ/P10+J+mB3WVNy64FfgY8BDwA/LBKHtK++BLZc/MMWXfSJWnZlWQvDneSvRBdR/Y8DlDHvr6Q7IPvX6Vv/fyc7F2bjRCK8AAX1nySgqzbYnWns5iNFj5CNzMrCRd0M7OScJeLmVlJNHSELmle+vnwaklVP6U3M7PWG/YRejonxRNkn/qvA+4n+2Xeo9VuM3Xq1Oju7h7W/ZmZjVbLli3bFBHTaq3XyC9F5wCrI2INgKRbgNPJvqI1qO7ubpYuXdrAXZqZjT6Sqv6SOK+RLpeDqPxZ8bo0r3+Qj0haKmnps88+28DdmZnZ7jRS0Af7SfiA/puI+FpEHBsRx06bVvMdg5mZDVMjBX0dleeimMng5+owM7M2aKSg3w/MlvQ6ZQMAfIDsXMpmZtYBw/5QNCJ2SPo42fkouoDrI+KRpiUzM7Mhaeh86BFxB3BHk7KYmVkDPMCFjU79fn+RnT22ksZ0DZhnVmQ+l4uZWUm4oJuZlYQLuplZSbigm5mVhD8UtVFp+x+3VrTXLPrqgHVi186K9vQ3za1oT3nDnOYHM2uAj9DNzErCBd3MrCRc0M3MSsJ96DYqbdv8TEV7y9MrB6wTOyv70Gcc866WZjJrlI/QzcxKwgXdzKwkGupykbQW2ArsBHZExLHNCGVmZkPXjD70kyJiUxO2Y9Y2/U/GJQ18s6qxlfPGjB3X0kxmjXKXi5lZSTRa0AO4U9IySR8ZbAUPEm1m1h6NFvQTIuIY4J3AOZLe3n8FDxJtZtYeDRX0iFifLjcCtwI+uYWZWYcMu6BL2kvSpL7rwDuAFc0KZmZmQ9PIt1wOAG6V1Led70TET5uSyszMhmzYBT0i1gBHNTGLmZk1wF9bNDMrCRd0M7OScEE3MysJF3Qzs5JwQTczKwkXdDOzknBBNzMrCRd0M7OScEE3MysJDxJto9JgA1oMFP2aMfhqZgXhI3Qzs5JwQTczK4maBV3S9ZI2SlqRmzdF0iJJq9Ll5NbGNDOzWuo5Qr8BmNdv3kXA4oiYDSxObbMRY9vWTRVT7No5YOoaN7FiGrfnvhWTWdHULOgR8Qvg9/1mnw7cmK7fCJzR5FxmZjZEw+1DPyAiegHS5fRqK3qQaDOz9mj5h6IeJNrMrD2G+z30DZJmRESvpBnAxmaGMmu1bVs2VbRj184B64wZt0dFe6z7za3ghnuEfjswP12fD9zWnDhmZjZc9Xxt8Wbgl8AhktZJOhv4EnCqpFXAqaltZmYdVLPLJSLOrLLo5CZnMTOzBvhcLjYq+VwuVkb+6b+ZWUm4oJuZlYQLuplZSbigm5mVhAu6mVlJuKCbmZWEC7qZWUm4oJuZlYQLuplZSbigm5mVhAu6mVlJDHeQ6EslPS1peZpOa21MMzOrZbiDRANcFRE9abqjubHMzGyohjtItJmZFUwjfegfl/Rw6pKZXG0lDxJtZtYewy3o1wAHAz1AL3BFtRU9SLSZWXsMq6BHxIaI2BkRu4BrgTnNjWXWYlLlNKjoN5kV27AKuqQZueZ7gRXV1jUzs/aoOQRdGiT6RGCqpHXA54ATJfWQHbasBT7awoxmZlaH4Q4SfV0LspiZWQM8SLSNSi9uWldznQmTple0u8ZPbFUcs6bwT//NzErCBd3MrCRc0M3MSsIF3cysJPyhqI1KO7f9oeY6Y8bvUdHWmK5WxTFrCh+hm5mVhAu6mVlJuKCbmZWE+9BtdKp6Qq6c8Am5bGTxEbqZWUm4oJuZlUQ9g0TPkrRE0kpJj0g6N82fImmRpFXpsuqoRWZm1nr1HKHvAM6PiMOA44FzJB0OXAQsjojZwOLUNjOzDqlnkOjeiHggXd8KrAQOAk4Hbkyr
3Qic0aqQZmZW25D60CV1A0cD9wIHREQvZEUfmF7lNh4k2sysDeou6JL2Bn4AnBcRW+q9nQeJNjNrj7oKuqRxZMX82xHxwzR7Q9/YoulyY2simplZPer5lovIhpxbGRFX5hbdDsxP1+cDtzU/npmZ1aueX4qeAHwQ+LWk5WneJcCXgO9JOhv4HfBXrYloZmb1qGeQ6HuAar+TPrm5cczMbLh8LhcbnXwuFysh//TfzKwkXNDNzErCBd3MrCRc0M3MSsIfitqoELt2VrR3vfzHmrfpmrBXq+KYtYSP0M3MSsIF3cysJFzQzcxKwn3oNirsfPmliva2rbXPJbfn1JmtimPWEj5CNzMrCRd0M7OSaGSQ6EslPS1peZpOa31cMzOrpp4+9L5Boh+QNAlYJmlRWnZVRFzeunhmreKTc1n51HP63F6gb+zQrZL6Bok2M7MCaWSQaICPS3pY0vWSJle5jQeJNjNrg0YGib4GOBjoITuCv2Kw23mQaDOz9hj2INERsSEidkbELuBaYE7rYpqZWS3DHiRa0ozcau8FVjQ/npmZ1auRQaLPlNQDBLAW+GhLEpqZWV0aGST6jubHMTOz4fIvRc3MSsIF3cysJFzQzcxKwgXdzKwkXNDNzErCBd3MrCRc0M3MSsIF3cysJFzQzcxKwoNE26gg9T928QAXVj4+QjczKwkXdDOzkqjn9Ll7SLpP0kNpkOjPp/mvk3SvpFWSvitpfOvjmplZNfX0oW8D5kbEC2mgi3sk/QT4JNkg0bdI+gpwNtkoRmaFs+vlFyrb21+saI8ZM/DYZs/JB7Q0k1mz1TxCj0zff8O4NAUwF1iQ5t8InNGShGZmVpd6h6DrSoNbbAQWAU8CmyNiR1plHXBQldt6kGgzszaoq6CnsUN7gJlkY4ceNthqVW7rQaLNzNpgSN9Dj4jNku4Gjgf2kzQ2HaXPBNa3IJ+NQg8++GBF+4ILLmh4m2+YPqGi/eGTDq5oa9ykAbe58LNfqGiveualhnNcfvnlFe2jjz664W2a9annWy7TJO2Xrk8ETgFWAkuA96XV5gO3tSqkmZnVVs8R+gzgRkldZC8A34uIhZIeBW6R9EXgQeC6FuY0M7Ma6hkk+mFgwPvCiFhD1p9uZmYF4HO5WOE899xzFe277rqr4W0+/druivZhR15Y0d5J14Db/PyeD1W0n/zd6oZz9H9sZs3kn/6bmZWEC7qZWUm4oJuZlYQLuplZSfhDUSucsWOb/2c5ZtzeFe1t7Fe5fMy4gTnG79P0HK14bGZ9fIRuZlYSLuhmZiXhgm5mVhJt7dDbvn07vb297bxLG4E2bdrU9G0+vf6JivY3b/iHivbh3dMH3OaFzauanqP/Y/P/gzWTj9DNzErCBd3MrCQaGST6Bkm/kbQ8TT2tj2tmZtU0Mkg0wKciYsFublthx44deBg6q2Xz5s1N3+aWF1+uaD/6xAP92k2/y0H1f2z+f7Bmquf0uQEMNki0mZkVyLAGiY6Ie9Oif5X0sKSrJE2octtXBol+/vnnmxTbzMz6G9Yg0ZLeBFwMHAq8BZgCXFjltq8MEj158uQmxTYzs/6GO0j0vIjoG+12m6RvADVH8p04cSJHHnnk0FPaqFLmd3KzZ8+uaPv/wZppuINEPyZpRpon4AxgRSuDmpnZ7jUySPRdkqYBApYDH2thTjMzq6GRQaLntiSRmZkNi0/ObIWzffv2TkdomTI/Nus8//TfzKwkXNDNzErCBd3MrCRc0M3MSsIfilrhTJ06taJ9yimndChJ8/V/bGbN5CN0M7OScEE3MysJF3Qzs5JwH7oVTk9P5eBXixYt6lASs5HFR+hmZiXhgm5mVhIu6GZmJaFsyNA23Zn0LPBbYCqwqW13PHzO2VwjIedIyAjO2WxFz/naiJhWa6W2FvRX7lRaGhHHtv2Oh8g5m2sk5BwJGcE5m22k5KzFXS5mZiXhgm5mVhKdKuhf69D9DpVzNtdIyDkSMoJzNttIyblbHelD
NzOz5nOXi5lZSbigm5mVRNsLuqR5kh6XtFrSRe2+/2okXS9po6QVuXlTJC2StCpdTu5wxlmSlkhaKekRSecWNOceku6T9FDK+fk0/3WS7k05vytpfCdz9pHUJelBSQtTu3A5Ja2V9GtJyyUtTfMK9bynTPtJWiDpsfR3+tYi5ZR0SNqHfdMWSecVKWMj2lrQJXUB/wm8EzgcOFPS4e3MsBs3APP6zbsIWBwRs4HFqd1JO4DzI+Iw4HjgnLT/ipZzGzA3Io4CeoB5ko4HLgOuSjmfB87uYMa8c4GVuXZRc54UET2570sX7XkH+A/gpxFxKHAU2X4tTM6IeDztwx7gzcCLwK1FytiQiGjbBLwV+FmufTFwcTsz1MjXDazItR8HZqTrM4DHO52xX97bgFOLnBPYE3gAOI7sl3hjB/tb6GC+mWT/wHOBhYAKmnMtMLXfvEI978A+wG9IX7Yoas5crncA/1vkjEOd2t3lchDwVK69Ls0rqgMiohcgXU7vcJ5XSOoGjgbupYA5UzfGcmAjsAh4EtgcETvSKkV57q8GPg3sSu39KWbOAO6UtEzSR9K8oj3vrweeBb6RurC+LmkvipezzweAm9P1omYcknYXdA0yz9+bHCJJewM/AM6LiC2dzjOYiNgZ2dvamcAc4LDBVmtvqkqS/gLYGBHL8rMHWbUIf6MnRMQxZN2V50h6e6cDDWIscAxwTUQcDfyBgnZdpM9F3gN8v9NZmqndBX0dMCvXngmsb3OGodggaQZAutzY4TxIGkdWzL8dET9MswuXs09EbAbuJuvz309S36AqRXjuTwDeI2ktcAtZt8vVFC8nEbE+XW4k6/OdQ/Ge93XAuoi4N7UXkBX4ouWE7IXxgYjYkNpFzDhk7S7o9wOz07cIxpO95bm9zRmG4nZgfro+n6zPumMkCbgOWBkRV+YWFS3nNEn7pesTgVPIPhxbArwvrdbxnBFxcUTMjIhusr/FuyLibylYTkl7SZrUd52s73cFBXveI+IZ4ClJh6RZJwOPUrCcyZm82t0Cxcw4dB34IOI04AmyPtXPdPpDhFyum4FeYDvZkcbZZP2pi4FV6XJKhzO+jezt/8PA8jSdVsCcRwIPppwrgM+m+a8H7gNWk73VndDp5z2X+URgYRFzpjwPpemRvv+boj3vKVMPsDQ99z8CJhctJ9kH9c8B++bmFSrjcCf/9N/MrCT8S1Ezs5JwQTczKwkXdDOzknBBNzMrCRd0M7OScEE3MysJF3Qzs5L4fxm/PzsB6/QKAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7fcc0c137860>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "env.reset()\n",
    "plt.figure()\n",
    "plt.imshow(get_screen().cpu().squeeze(0).permute(1, 2, 0).numpy(), interpolation='none')\n",
    "plt.title('Example extracted screen')\n",
    "plt.show()\n",
    "           \n",
    "BATCH_SIZE = 128\n",
    "GAMMA = 0.999\n",
    "EPS_START = 0.9\n",
    "EPS_END = 0.05\n",
    "EPS_DECAY = 200\n",
    "TARGET_UPDATE = 10\n",
    "\n",
    "policy_net = DQN()\n",
    "target_net = DQN()\n",
    "target_net.load_state_dict(policy_net.state_dict())\n",
    "target_net.eval()\n",
    "\n",
    "# if gpu is not supported , command out\n",
    "'''\n",
    "if use_cuda:\n",
    "    policy_net.cuda()\n",
    "    target_net.cuda()\n",
    "'''\n",
    "optimizer = optim.RMSprop(policy_net.parameters())\n",
    "memory = ReplayMemory(10000)\n",
    "\n",
    "\n",
    "steps_done = 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7fcbfe238dd8>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Complete\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7fcbfe238dd8>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "\n",
    "#num_episodes = 50\n",
    "num_episodes = 10\n",
    "for i_episode in range(num_episodes):\n",
    "    # Initialize the environment and state\n",
    "    env.reset()\n",
    "    last_screen = get_screen()\n",
    "    current_screen = get_screen()\n",
    "    state = current_screen - last_screen\n",
    "    for t in count():\n",
    "        # Select and perform an action\n",
    "        action = select_action(state)\n",
    "        _, reward, done, _ = env.step(action[0, 0])\n",
    "        reward = Tensor([reward])\n",
    "\n",
    "        # Observe new state\n",
    "        last_screen = current_screen\n",
    "        current_screen = get_screen()\n",
    "        if not done:\n",
    "            next_state = current_screen - last_screen\n",
    "        else:\n",
    "            next_state = None\n",
    "\n",
    "        # Store the transition in memory\n",
    "        memory.push(state, action, next_state, reward)\n",
    "\n",
    "        # Move to the next state\n",
    "        state = next_state\n",
    "\n",
    "        # Perform one step of the optimization (on the target network)memory\n",
    "        \n",
    "        \n",
    "        optimize_model()\n",
    "        if done:\n",
    "            episode_durations.append(t + 1)\n",
    "            plot_durations()\n",
    "            break\n",
    "    # Update the target network\n",
    "    if i_episode % TARGET_UPDATE == 0:\n",
    "        target_net.load_state_dict(policy_net.state_dict())\n",
    "\n",
    "print('Complete')\n",
    "env.close()\n",
    "plt.ioff()\n",
    "plt.show()\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
